from keypoints_Net import CoordRegression
from data_process import *
import torch.optim as optim
import dsntnn
from data_process_fashion_hw import dressDataset

# Module-wide default compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def calculate_loss(epochs=100, *, batch_size=8, lr=2e-4, n_locations=13,
                   heatmap_size=64.0, save_dir='../Models/'):
    """Train a ``CoordRegression`` keypoint model on ``dressDataset``.

    For every epoch the whole dataset is iterated in shuffled mini-batches
    (the incomplete last batch is dropped). The model is optimized with
    RMSprop on the average of dsntnn's per-location Euclidean loss plus
    its Jensen-Shannon heatmap regularization loss. At the end of each
    epoch the whole model object (wrapped in ``DataParallel`` when CUDA is
    available) is saved to ``<save_dir>/dress_kp<epoch>.pth``.

    Args:
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size.
        lr: RMSprop learning rate.
        n_locations: Number of keypoint locations the model predicts.
        heatmap_size: Denominator used to map landmark coordinates into
            dsntnn's normalized range via ``(2 * x + 1) / heatmap_size - 1``.
            NOTE(review): presumably the landmark pixel grid width; confirm
            against the dataset's preprocessing.
        save_dir: Directory for per-epoch checkpoints; created if missing.
    """
    import os

    dataloader = DataLoader(dataset=dressDataset, batch_size=batch_size,
                            shuffle=True, drop_last=True)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    model = CoordRegression(n_locations=n_locations)
    if use_cuda:
        # Workaround to make torch CUDA work on Windows 10,
        # see https://www.jianshu.com/p/8e74b6e057ea
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.to(device)

    # Build the optimizer only after the model is on its final device,
    # as recommended by the torch.optim documentation.
    optimizer = optim.RMSprop(model.parameters(), lr=lr, alpha=0.85)

    # Create the checkpoint directory up front so torch.save does not fail
    # after an entire epoch of training.
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        model.train()
        print("Epoch: {}/{}".format(epoch + 1, epochs))

        for i_batch, (img, landmarks) in enumerate(dataloader):
            # Keep the first two values per location (presumably x, y).
            landmarks = landmarks[:, :, 0:2]
            landmarks = (landmarks * 2 + 1) / heatmap_size - 1
            # as_tensor avoids copy-constructing from an existing tensor.
            landmarks = torch.as_tensor(landmarks, dtype=torch.float32,
                                        device=device)
            img = img.to(device)

            coords, heatmaps = model(img)
            if i_batch == 0:
                # Spot-check targets vs. predictions once per epoch.
                print('landmarks', landmarks)
                print('coords', coords)

            euc_losses = dsntnn.euclidean_losses(coords, landmarks)
            reg_losses = dsntnn.js_reg_losses(heatmaps, landmarks, sigma_t=1.0)
            loss = dsntnn.average_loss(euc_losses + reg_losses)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(str(i_batch), ',current_loss,{:.3f}'.format(loss.item()))

        torch.save(model, os.path.join(save_dir, 'dress_kp' + str(epoch) + '.pth'))


if __name__ == "__main__":
    # Script entry point: run the full 100-epoch training, then report completion.
    calculate_loss(epochs=100)
    print("The end!")
